{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trajectory Tracking with Integral Loss"
   ]
  },
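  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch; import torch.nn as nn\n",
    "from torchdyn.models import *; from torchdyn import *"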
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Integral Loss\n",
    "\n",
    "We consider a loss of type\n",
    "\n",
    "$$\n",
    "    \\ell_\\theta := \\int_0^S g(z(\\tau))d\\tau\n",
    "$$\n",
    "\n",
    "where $z(s)$ satisfies the neural ODE initial value problem \n",
    "\n",
    "$$\n",
    "    \\left\\{\n",
    "    \\begin{aligned}\n",
    "        \\dot{z}(s) &= f(z(s), \\theta(s))\\\\\n",
    "        z(0) &= x\n",
    "    \\end{aligned}\n",
    "    \\right. \\quad s\\in[0,S]\n",
    "$$\n",
    "\n",
    "and $\\theta(s)$ is parametrized with some spectral method (`Galerkin-style`), i.e. $\\theta(s)=\\theta(s,\\omega)$ [$\\omega$: parameters of $\\theta(s)$].\n",
    "\n",
    "**REMARK:** In `torchdyn`, we do not need to evaluate the integral above in the forward pass of the ODE integration.\n",
    "In fact, we compute the gradient $d\\ell/d\\omega$ just by solving backward \n",
    "\n",
    "$$\n",
    "    \\begin{aligned}\n",
    "        \\dot\\lambda (s) &= -\\frac{\\partial f}{\\partial z}\\lambda(s) + \\frac{\\partial g}{\\partial z} \\\\\n",
    "        \\dot\\mu (s) &= -\\frac{\\partial f}{\\partial \\theta}\\frac{\\partial \\theta}{\\partial \\omega}\\lambda(s)\n",
    "    \\end{aligned}~~\\text{with}~~\n",
    "    \\begin{aligned}\n",
    "        \\lambda (S) &= 0\\\\\n",
    "        \\mu (S) &= 0\n",
    "    \\end{aligned}\n",
    "$$\n",
    "\n",
    "The gradient is then $\\frac{d\\ell}{d\\omega} = \\mu(0)$. Check out [this paper](https://arxiv.org/abs/2003.08063) for more details on the integral adjoint for Neural ODEs."
   ]
  },
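  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Aside:** one way to see what the adjoint formulation saves is to write the integral loss as the terminal value of an extra state. This is only an illustrative sketch: the scalar $c(s)$ is notation introduced here and is not used elsewhere in this notebook. Defining\n",
    "\n",
    "$$\n",
    "    \\dot{c}(s) = g(z(s)), \\quad c(0) = 0\n",
    "$$\n",
    "\n",
    "gives $c(S) = \\int_0^S g(z(\\tau))d\\tau = \\ell_\\theta$, so the loss could be obtained by integrating $c$ alongside $z$. With the integral adjoint this extra state is not needed to obtain the gradient: $g$ only enters the backward dynamics through the forcing term $\\partial g/\\partial z$."
   ]
  },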
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example:** Use a neural ODE to track a 2D curve $\\gamma:[0,\\infty)\\rightarrow \\mathbb{S}_2$ ($\\mathbb{S}_2$: unit circle, $\\mathbb{S}_2:=\\{x\\in\\mathbb{R}^2:||x||_2=1\\}$), i.e.\n",
    "\n",
    "$$\n",
    "    \\gamma(s) := [\\cos(2\\pi s), \\sin(2\\pi s)]\n",
    "$$\n",
    "\n",
    "which is periodic with period $1$: $\\forall n\\in\\mathbb{N}, \\forall s\\in[0,\\infty)~~\\gamma(s) = \\gamma(s+n)$.\n",
    "\n",
    "Suppose we train the neural ODE for $s\\in[0,1]$. We can then easily set up the integral cost as\n",
    "\n",
    "$$\n",
    "    \\ell_\\theta := \\int_0^1 \\|z(\\tau)-\\gamma(\\tau)\\|_2^2 d\\tau\n",
    "$$"
   ]
  },
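  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IntegralLoss(nn.Module):\n",
    "    \"\"\"Integrand g(z(s)) of the integral loss: squared tracking error w.r.t. gamma(s).\"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        \n",
    "    def y(self, s):\n",
    "        # Reference point gamma(s) on the unit circle, shaped (1, 2) to broadcast over the batch.\n",
    "        # NOTE: relies on the global `device` defined in the Learner cell below.\n",
    "        return torch.tensor([torch.cos(2*np.pi*s), torch.sin(2*np.pi*s)])[None, :].to(device)\n",
    "        \n",
    "    def forward(self, s, x):\n",
    "        # Squared error between the states x and gamma(s), averaged over batch and components\n",
    "        int_loss = ((self.y(s)-x)**2).mean()\n",
    "        # Overwrite a single status line with the current integrand value and depth\n",
    "        print(f'\\rIntegral loss: {int_loss} s: {s}', end='')\n",
    "        return int_loss"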
   ]
  },
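  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the quantity evaluated by `forward` above can be written out as a formula. This is a sketch read off from the code: $N$ (the batch size) and $z_i(s)$ (the state of the $i$-th sample) are notation introduced here. Since `.mean()` averages over both the batch and the two state components,\n",
    "\n",
    "$$\n",
    "    g(z(s)) = \\frac{1}{2N}\\sum_{i=1}^{N} \\|\\gamma(s) - z_i(s)\\|_2^2, \\qquad \\frac{\\partial g}{\\partial z_i} = \\frac{1}{N}\\left(z_i(s) - \\gamma(s)\\right)\n",
    "$$\n",
    "\n",
    "so the forcing term $\\partial g/\\partial z$ in the adjoint dynamics is proportional to each sample's tracking error $z_i(s)-\\gamma(s)$ at the current depth."
   ]
  },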
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** In this case we do not define a dataset of initial conditions and load it into a dataloader. Instead, at each step, we sample new ICs from a normal distribution centered at $\\gamma(0) = [1,0]$\n",
    "\n",
    "$$\n",
    "    x_t \\sim \\mathcal{N}(\\gamma(0),\\sigma^2\\mathbb{I})\n",
    "$$\n",
    "\n",
    "with standard deviation $\\sigma=0.5$, as set in the `training_step` of the `Learner` below. We still need to define a \"dummy\" `trainloader` to \"trick\" `pytorch_lightning`'s API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dummy trainloader\n",
    "train = torch.utils.data.TensorDataset(torch.zeros(1), torch.zeros(1))\n",
    "trainloader = torch.utils.data.DataLoader(train, batch_size=1, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Learner**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import pytorch_lightning as pl\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "class Learner(pl.LightningModule):\n",
    "    def __init__(self, model:nn.Module):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return self.model(x)    \n",
    "    \n",
    "    def training_step(self, batch, batch_idx):    \n",
    "        # Sample a fresh batch of initial conditions around the \"nominal\" one, gamma(0) = [1, 0]\n",
    "        x = torch.tensor([1.,0.])\n",
    "        x = x + 0.5*torch.randn(4096, 2)\n",
    "        y_hat = self.model(x.to(device))\n",
    "        # We need to evaluate a \"dummy loss\" (the summed output of the model, scaled by zero)\n",
    "        # to construct the graph for the 'backward()' call that triggers the integral adjoint \n",
    "        loss = 0.*y_hat.sum()\n",
    "        tensorboard_logs = {'train_loss': loss}\n",
    "        # During training `loss` is always 0; the integral loss itself is printed by `IntegralLoss`\n",
    "        return {'loss': loss, 'log': tensorboard_logs}   \n",
    "    \n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.Adam(self.model.parameters(), lr=1e-3)\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return trainloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameter Varying Neural ODE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We use a Galerkin Neural ODE with one hidden layer, \n",
    "# a Fourier spectrum (period=1) and only 2 harmonics\n",
    "f = nn.Sequential(DepthCat(1),\n",
    "                  GalLinear(2, 64, \n",
    "                            FourierExpansion, \n",
    "                            n_harmonics=2,\n",
    "                            dilation=False,\n",
    "                            shift=False),\n",
    "                  nn.Tanh(),\n",
    "                  nn.Linear(64, 2))\n",
    "\n",
    "# Define the model\n",
    "model = NeuralDE(f, \n",
    "                 solver='dopri5',\n",
    "                 sensitivity='adjoint',\n",
    "                 atol=1e-6,\n",
    "                 rtol=1e-6,\n",
    "                 intloss=IntegralLoss()).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name  | Type     | Params\n",
      "-----------------------------------\n",
      "0 | model | NeuralDE | 898   \n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b93e2c5db5f4dcb88575a68689640db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Integral loss: 0.31433892250061035 s: -0.01702913455665111591\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 112,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the model\n",
    "learn = Learner(model)\n",
    "trainer = pl.Trainer(min_epochs=1500, max_epochs=1500, gpus=1)\n",
    "trainer.fit(learn)"
   ]
  },
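  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the `898` parameters reported in the model summary above can be reproduced by hand. This is a sketch under an assumption that is not stated in this notebook: that `GalLinear` expands every weight and bias entry of a $2\\to 64$ linear layer on the Fourier basis, with one cosine and one sine coefficient per harmonic.\n",
    "\n",
    "$$\n",
    "    \\underbrace{(2+1)\\cdot 64}_{\\text{weight and bias entries}}\\cdot\\underbrace{2}_{\\text{harmonics}}\\cdot\\underbrace{2}_{\\cos,\\,\\sin} + \\underbrace{64\\cdot 2 + 2}_{\\text{nn.Linear(64, 2)}} = 768 + 130 = 898\n",
    "$$\n",
    "\n",
    "Under that assumption, most of the trainable parameters are the spectral coefficients $\\omega$ of the depth-varying first layer, while the output layer contributes only 130 constant parameters."
   ]
  },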
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Plots**\n",
    "\n",
    "We first evaluate the model on a test set of initial conditions sampled from $\\mathcal{N}(\\gamma(0),\\sigma^2\\mathbb{I})$. The training step above samples ICs with standard deviation $\\sigma=0.5$; here we test with $\\sigma=0.2$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_span = torch.linspace(0, 5, 500)\n",
    "x_test = torch.tensor([1.,0.])\n",
    "x_test = x_test + 0.2*torch.randn(100, 2)\n",
    "trajectory = model.trajectory(x_test.to(device), s_span).detach().cpu()"
   ]
  },
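  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Depth evolution of the system** -> The system is trained for $s\\in[0,1]$ and then extrapolated up to $s=5$. In the plot below, the reference curve is drawn in orange on the training interval and in red beyond it; the model trajectories are in blue."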
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y1 = np.cos(2*np.pi*s_span)\n",
    "y2 = np.sin(2*np.pi*s_span)\n",
    "\n",
    "fig = plt.figure(figsize=(8,3))\n",
    "ax = fig.add_subplot(111)\n",
    "ax.scatter(s_span[:100], y1[:100], color='orange',alpha=0.5)\n",
    "ax.scatter(s_span[:100], y2[:100], color='orange',alpha=0.5)\n",
    "ax.scatter(s_span[100:], y1[100:], color='red',alpha=0.5)\n",
    "ax.scatter(s_span[100:], y2[100:], color='red',alpha=0.5)\n",
    "for i in range(len(x_test)):\n",
    "    ax.plot(s_span, trajectory[:,i,:], color='blue', alpha=.1)\n",
    "ax.set_xlabel(r\"$s$ [depth]\")\n",
    "ax.set_ylabel(r\"$z(s)$ [state]\")\n",
    "ax.set_title(r\"Depth evolution of the system\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**State-space trajectories**\n",
    "-> All the (random) ICs converge to the desired trajectory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(3,3))\n",
    "ax = fig.add_subplot(111)\n",
    "for i in range(len(x_test)):\n",
    "    ax.plot(trajectory[:,i,0], trajectory[:,i,1], color='blue', alpha=.1)\n",
    "    ax.scatter(trajectory[0,i,0], trajectory[0,i,1], color='black', alpha=.1)\n",
    "ax.plot(y1[:100],y2[:100], color='red')\n",
    "ax.set_xlabel(r\"$z_0$\")\n",
    "ax.set_ylabel(r\"$z_1$\")\n",
    "ax.set_title(r\"State-Space Trajectories\");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "py37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
